function dataout = manifold_analysis_functions(datain, sflag)

    Fsi = 20;
    Fsb = 20;
    tstep = [2, 5, 2];
    
    switch sflag
        case {'estimate_dims_mle', 'estimate_dims_eig'}
            % Intrinsic-dimension estimate for every (cond, whisk, mouse) cell.
            %   'estimate_dims_mle' -> intrinsic_dim(..., 'MLE')
            %   'estimate_dims_eig' -> intrinsic_dim(..., 'EigValue')
            % Input : datain{1} - nc x nw x nm cell array; each element is
            %         transposed before being passed to intrinsic_dim.
            % Output: dataout{1} - nc x nw x nm array of rounded estimates.
            if strcmp(sflag, 'estimate_dims_mle')
                estimator = 'MLE';
            else
                estimator = 'EigValue';
            end
            spkCells = datain{1};
            [nCond, nWhisk, nMouse] = size(spkCells);
            dimEst = zeros(nCond, nWhisk, nMouse);
            for ic = 1: nCond
                for iw = 1: nWhisk
                    for im = 1: nMouse
                        samples = spkCells{ic, iw, im}';
                        dimEst(ic, iw, im) = round(intrinsic_dim(samples, estimator));
                    end
                end
            end
            dataout = {dimEst};
            
        case 'compute_manifold'
            % Spectral embedding of every (cond, whisk, mouse) cell.
            % Input : datain{1} - nc x nw x nm cell array; each element is cast
            %         to double and transposed before embedding.
            %         datain{2} - optional third argument forwarded to
            %         spectral_embedding (default 0.075).
            % Output: dataout{1} - cell array the size of datain{1} holding the
            %         transpose of the first output cell of spectral_embedding.
            src = datain{1};
            kn = 0.075;
            if length(datain) >= 2
                kn = datain{2};
            end
            [nc, nw, nm] = size(src);
            emb = cell(size(src));
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        res = spectral_embedding(double(src{i, j, k})', [], kn);
                        emb{i, j, k} = res{1}';
                        disp(['cond ', num2str(i), ', whisk ', num2str(j), ', mouse ', num2str(k)])
                    end
                end
            end
            dataout = {emb};
            
        case 'binarize_state'
            % Binarize each cell: an entry is true where it exceeds the standard
            % deviation of its own row (std along dimension 2). The comparison
            % is against the std itself, not mean + std.
            % Output: dataout{1} - same-size cell array of logical arrays.
            src = datain{1};
            [nc, nw, nm] = size(src);
            binarized = src;
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        x = src{i, j, k};
                        binarized{i, j, k} = x > std(x, [], 2);
                    end
                end
            end
            dataout = {binarized};

        case 'get_effective_frames'
            % Keep only the columns (frames) of each cell whose column sum is
            % strictly positive; all other columns are dropped.
            % Output: dataout{1} - same-size cell array of the reduced matrices.
            src = datain{1};
            [nc, nw, nm] = size(src);
            kept = src;
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        x = src{i, j, k};
                        active = sum(x, 1) > 0;
                        kept{i, j, k} = x(:, active);
                    end
                end
            end
            dataout = {kept};
            
        case 'reform_epoch_data'
            % Flatten epoched data: take the second element of each cell (the
            % spike data, per the original inline note) and reshape it into a
            % 2-D matrix that keeps dimension 1 and folds every remaining
            % dimension into the columns.
            % Output: dataout{1} - same-size cell array of 2-D matrices.
            epochs = datain{1};
            [nc, nw, nm] = size(epochs);
            flat = epochs;
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        spk = epochs{i, j, k}{2};
                        flat{i, j, k} = reshape(spk, size(spk, 1), []);
                    end
                end
            end
            dataout = {flat};
            
        case 'persistent_barcode_ripser'
            % Run the external `ripser` command-line tool (--dim 1, with a
            % distance threshold) on the pairwise-distance matrix of every
            % (cond, whisk, mouse) cell and post-process its output with pbclean.
            % Input : datain{1} - nc x nw x nm cell array; each element is
            %         transposed, de-duplicated by row and cast to single.
            %         datain{2} - optional directory for ripser's input/output
            %         CSV files (default 'C:\ripser').
            % Output: dataout{1} - cell array of pbclean results, each with the
            %         extra fields .thres (threshold passed to ripser) and .max
            %         (largest pairwise distance).
            % Requires `ripser` to be callable from the system shell.
            data = datain{1};
            if length(datain) >= 2 && ~isempty(datain{2})
                ripdir = char(datain{2});
            else
                ripdir = 'C:\ripser';
            end
            infile = fullfile(ripdir, 'input.csv');
            outfile = fullfile(ripdir, 'output.csv');
            [nc, nw, nm] = size(data);
            
            datao = data;
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        datt = data{i, j, k};
                        datt = unique(datt', 'rows');
                        datt = single(datt);
                        
                        %%% reduce density first %%%
                        % NOTE(review): presumably subsamples to about 3000
                        % points; depends on reduce_density - confirm.
                        ratio = 3000 / size(datt, 1);
                        datt = reduce_density(datt, ratio);
                        
                        %%% compute distance matrix first %%%
                        dis1 = pdist(datt);
                        mn = min(dis1);
                        mx = max(dis1);
                        dis = squareform(dis1);
                        
                        % Scan radii from mn + 0.5 in steps of 0.1 and stop at
                        % the first one whose graph (dis < r) is connected. If
                        % the scan range is empty, fall back to mx so the
                        % threshold below is always defined.
                        ii = mx;
                        for rr = mn + 0.5: 0.1: mx
                            ii = rr;
                            if length(unique(conncomp(graph(dis < rr)))) == 1
                                break
                            end
                        end
                        thres = min(2 * ii, mx / 2);
                        
                        %%% run ripser to compute persistent barcode %%%
                        writematrix(dis, infile);
                        % remove any previous result so a failed run cannot be
                        % mistaken for a fresh one
                        if isfile(outfile)
                            delete(outfile);
                        end
                        cmd = sprintf('ripser "%s" --dim 1 --threshold %.9g > "%s"', ...
                            infile, thres, outfile);
                        status = system(cmd);
                        if status ~= 0 || ~isfile(outfile)
                            error('manifold_analysis_functions:ripser', ...
                                'ripser failed (status %d) for cell (%d,%d,%d)', ...
                                status, i, j, k);
                        end
                        dt = readmatrix(outfile, 'outputtype', 'string');
                        dtc = pbclean(dt);
                        dtc.thres = thres;
                        dtc.max = mx;
                        
                        %%% output %%%
                        datao{i, j, k} = dtc;
                        disp(['done (', num2str(i), ',', num2str(j), ',', num2str(k), ')'])
                    end
                end
            end
            
            dataout = {datao};
            
        case 'state_clustering'
            % Threshold-sweep clustering. For each cell, pairwise distances
            % between its columns are computed and, for thresholds
            % linspace(mn, mx, 100)(2:end), the connected components of the
            % graph (dis < threshold) are recorded. The sweep stops early once
            % fewer than 3 components remain.
            % Input : datain{1} - nc x nw x nm cell array; each element is cast
            %         to single and transposed (columns are the points).
            %         datain{2} - optional logical; when true the sweep's upper
            %         bound is shrunk to match a previous result.
            %         datain{3} - previous result (required only when
            %         datain{2} is true); its column count sets the new bound.
            % Output: dataout{1} - cell array of npoints x nsteps label matrices,
            %         one column per threshold actually visited.
            dt = datain{1};
            cflag = false;
            if length(datain) >= 2
                cflag = datain{2};
            end
            if cflag
                if length(datain) < 3
                    error('manifold_analysis_functions:state_clustering', ...
                        'datain{3} (previous clustering) is required when datain{2} is true');
                end
                pbold = datain{3};
            end
            stp = 100;
            [nc, nw, nm] = size(dt);
            
            datao = cell(nc, nw, nm);
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        dtmp = single(dt{i, j, k})';
                        dis = squareform(pdist(dtmp));
                        nn = size(dis, 1);
                        % The diagonal is still 0 here, so mn is 0 whenever
                        % there is at least one point.
                        mn = min(min(dis));
                        mx = max(max(dis));
                        if cflag
                            pboldt = pbold{i, j, k};
                            mxt = size(pboldt, 2);
                            stpt = linspace(mn, mx, stp);
                            mx = stpt(mxt);
                        end
                        % Exclude self-pairs from the threshold graphs (self
                        % loops never change the connected components anyway).
                        dis(1: nn + 1: nn ^ 2) = Inf;
                        stps = linspace(mn, mx, stp);
                        dtmpo = zeros(nn, stp - 1);
                        for ii = 2: length(stps)
                            ctmp = conncomp(graph(dis < stps(ii)));
                            dtmpo(:, ii - 1) = ctmp;
                            if mod(ii, 10) == 0
                                fprintf([num2str(ii), '***'])
                            end
                            
                            if max(ctmp) < 3
                                break
                            end
                        end
                        datao{i, j, k} = dtmpo(:, 1: ii - 1);
                        fprintf(['\n done (', num2str(i), ',', num2str(j), ',', num2str(k), ') \n'])
                    end
                end
            end
            
            dataout = {datao};

        case 'extract_clusters_linkage'
            % Build the cluster hierarchy from a threshold sweep.
            % Input : datain{1} - cell array; each element is an np x nn matrix
            %         of component labels, one column per threshold (as produced
            %         by 'state_clustering').
            % Output: dataout{1} - cell array; each element is an (nn+1) x 1 cell.
            %         Level 1 lists the np points; level L+1 lists, for each
            %         distinct label of column L (sorted), the unique labels of
            %         the previous level (point indices when L = 1) it contains.
            labelSets = datain{1};
            [nc, nw, nm] = size(labelSets);
            
            pbo = cell(nc, nw, nm);
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        lab = labelSets{i, j, k};
                        [np, nn] = size(lab);
                        tree = cell(nn + 1, 1);
                        tree{1} = num2cell(1: np);
                        
                        prevLab = (1: np)';
                        pct = 10;
                        printEvery = round(nn / 10);
                        for lev = 1: nn
                            curLab = lab(:, lev);
                            ids = unique(curLab);
                            for c = 1: length(ids)
                                tree{lev + 1}{c} = unique(prevLab(curLab == ids(c)));
                            end
                            prevLab = curLab;
                            % progress report roughly every 10% of the levels
                            if mod(lev, printEvery) == 0
                                fprintf([num2str(pct), '%%***'])
                                pct = pct + 10;
                            end
                        end
                        pbo{i, j, k} = tree;
                        fprintf(['\n done (', num2str(i), ',', num2str(j), ',', num2str(k), ') \n'])
                    end
                end
            end
            
            dataout = {pbo};
            
        case 'extract_clusters'
            % Extract clusters from the linkage hierarchy.
            % A digraph is built with edges parent -> child between consecutive
            % levels. Edges into level-1 nodes (points) weigh 1; edges into a
            % higher-level node weigh that node's point count. Edges with weight
            % <= thres are removed. In each remaining weak component, chains of
            % nodes with out-degree < 2 that reach a leaf are found, and the
            % head (in-degree 0) of every such branch defines one cluster.
            % Input : datain{1} - linkage cell array from 'extract_clusters_linkage'.
            %         datain{2} - label matrices from 'state_clustering'.
            %         datain{3} - optional weight threshold (default 100).
            % Output: dataout{1} - cell array; each element is a cell of point
            %         index vectors, one per extracted cluster ([] if none).
            pbl = datain{1};
            pbs = datain{2};
            thres = 100;
            if length(datain) >= 3 && ~isempty(datain{3})
                thres = datain{3};
            end
            [nc, nw, nm] = size(pbl);
            
            pbo = cell(nc, nw, nm);
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        ptmp = pbl{i, j, k};
                        stmp = pbs{i, j, k};
                        nn = length(ptmp);
                        np = length(ptmp{1});
                        % Global node numbering: level L occupies nodes
                        % len(L)+1 .. len(L+1); level 1 is nodes 1..np.
                        lths = cellfun(@length, ptmp);
                        lths = [0; lths];
                        len = cumsum(lths);
                        
                        %%% build parent -> child edge list %%%
                        s = NaN(len(end), 1);
                        t = s;
                        ss = 0;
                        for ii = 2: nn
                            ttemp = cell2mat(cellfun(@(x) x + len(ii - 1), ptmp{ii}', 'uniformoutput', false));
                            tmpc = num2cell(1: lths(ii + 1));
                            stemp = cell2mat(cellfun(@(x, y) ones(length(x), 1) * (len(ii) + y), ptmp{ii}, tmpc, 'uniformoutput', false)');
                            [ttemp, id] = sort(ttemp);
                            t(len(ii - 1) + 1: len(ii)) = ttemp;
                            s(len(ii - 1) + 1: len(ii)) = stemp(id);
                            if mod(ii, round(nn / 10)) == 0
                                fprintf([num2str(ss), '%%***'])
                                ss = ss + 10;
                            end
                        end
                        s = s(~isnan(s));
                        t = t(~isnan(t));
                        
                        %%% compute weight %%%
                        % Point count per label in sorted label order. unique +
                        % accumarray avoids hist(), which treats a scalar second
                        % argument as a number of bins rather than a bin center.
                        w = zeros(length(s), 1);
                        w(1: np) = 1;
                        for ii = 1: nn - 2
                            [~, ~, lab] = unique(stmp(:, ii));
                            lent = accumarray(lab(:), 1);
                            w(len(ii + 1) + 1: len(ii + 2)) = lent;
                        end
                        
                        %%% build digraph %%%
                        % node names are '1'..'N' in index order
                        g = digraph(s, t, w, cellfun(@num2str, num2cell(1: length(unique(union(s, t)))), 'uniformoutput', false));
                        idx = find(g.Edges.Weight > thres);
                        sg = rmedge(g, setdiff(1: height(g.Edges), idx));
                        idtmp = cellfun(@str2double, unique(sg.Edges.EndNodes));
                        sg = rmnode(sg, setdiff(1: height(sg.Nodes), idtmp));
                        
                        %%% get clusters %%%
                        idt = conncomp(sg, 'type', 'weak');
                        % head nodes collected over ALL components
                        nodeuse = {};
                        for ii = 1: max(idt)
                            sgtmp = subgraph(sg, idt == ii);
                            
                            %%% find leaf branches %%%
                            % chains of nodes with at most one child; keep the
                            % chains that contain a leaf (out-degree 0)
                            odeg = outdegree(sgtmp);
                            sgt1 = subgraph(sgtmp, find(odeg < 2));
                            ltmp = conncomp(sgt1, 'type', 'weak');
                            
                            sname = table2array(sgt1.Nodes);
                            leafb = [];
                            for jj = 1: max(ltmp)
                                ltmpt = sname(ltmp == jj);
                                otmp = cellfun(@(x) outdegree(sgtmp, x), ltmpt);
                                if any(otmp == 0)
                                    leafb = [leafb; cellfun(@str2double, ltmpt)];
                                end
                            end
                            
                            sgf = subgraph(sgtmp, cellfun(@num2str, num2cell(leafb), 'uniformoutput', false));
                            
                            %%% extract head node of each branch %%%
                            % accumulate; do not reset per component
                            sgftmp = conncomp(sgf, 'type', 'weak');
                            sname = table2array(sgf.Nodes);
                            for jj = 1: max(sgftmp)
                                ndid = sname(sgftmp == jj);
                                nodeuse = [nodeuse; ndid(indegree(sgf, ndid) == 0)];
                            end
                        end
                        
                        %%% get all elements in each cluster %%%
                        % Start dfsearch from the numeric node index so it returns
                        % numeric indices; keep only level-1 nodes (points 1..np).
                        for ii = 1: length(nodeuse)
                            tmpt = dfsearch(g, findnode(g, nodeuse{ii}));
                            pbo{i, j, k}{ii} = tmpt(tmpt <= np);
                        end
                        
                        fprintf(['\n done (', num2str(i), ',', num2str(j), ',', num2str(k), ') \n'])
                    end
                end
            end
            
            dataout = {pbo};
            
        case 'temporal_segment'
            % Epoch the embedded trajectory around each valid onset frame and
            % label every time bin with its nearest cluster center.
            % Input : datain{1} - cluster cell array (column indices into
            %                     datain{4} per cluster), e.g. from 'extract_clusters'.
            %         datain{2} - struct with cell fields .s (onset frame per
            %                     trial) and .b (only its NaN pattern is used here,
            %                     to mark invalid trials; empty means all valid).
            %         datain{3} - binned activity; columns with a positive sum are
            %                     the frames that datain{4} covers.
            %         datain{4} - dims x kept-frames coordinates (NOTE(review):
            %                     presumably the manifold embedding - confirm).
            %         datain{5} - trial groups from 'trial_separation' (used only
            %                     for conditions i > 1).
            % Output: dataout{1} - struct of nc x nw x nm cell arrays:
            %         idxa/idxw   nearest-cluster index, trials x bins
            %         ctra/ctrw   cluster-center trajectories, trials x bins x dims
            %         mtxa/mtxw   compute_transition_matrix results
            %         dta/dtw     epoched traces, trials x frames x dims
            %         tempidx/tempidxw  frame indices of each epoch
            %         The *w fields hold the four groups {whwp, whnwp, nwhwp,
            %         nwhnwp} and are empty for i == 1.
            pb = datain{1};
            fdata = datain{2};
            spkb = datain{3};
            mfs = datain{4};
            tid = datain{5};
            pstep = 1;              % pre-onset window, in units of Fsi frames
            Fsin = 20;
            ratio = Fsi / Fsin;     % frames averaged per output bin (must be an integer)
            [nc, nw, nm] = size(pb);
            
            dto.idxa = cell(nc, nw, nm);
            dto.idxw = cell(nc, nw, nm);
            dto.ctra = cell(nc, nw, nm);
            dto.ctrw = cell(nc, nw, nm);
            dto.mtxa = cell(nc, nw, nm);
            dto.mtxw = cell(nc, nw, nm);
            dto.dta = cell(nc, nw, nm);
            dto.dtw = cell(nc, nw, nm);
            dto.tempidx = cell(nc, nw, nm);
            dto.tempidxw = cell(nc, nw, nm);
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        fdts = fdata.s{i, j, k};
                        fdtb = fdata.b{i, j, k};
                        if isempty(fdtb)
                            fdtb = true(size(fdts));
                        end
                        pbt = pb{i, j, k};
                        spkbt = spkb{i, j, k};
                        mfst = mfs{i, j, k};
                        idvalid = ~isnan(fdtb) & ~isnan(fdts);
                        iduse = fdts(idvalid);
                        % Put the kept frames back on the full frame axis;
                        % frames without activity stay at zero.
                        id1 = find(sum(spkbt, 1) > 0);
                        dataraw = zeros(size(mfst, 1), size(spkbt, 2));
                        dataraw(:, id1) = mfst;
                        
                        %%% get centers of the clusters %%%
                        centers = zeros(size(mfst, 1), length(pbt));
                        for ii = 1: length(pbt)
                            centers(:, ii) = mean(mfst(:, pbt{ii}), 2);
                        end
                        
                        %%% get epoched trace (MLed) %%%
                        % NOTE(review): assumes each epoch lies inside dataraw;
                        % an onset within pstep*Fsi frames of the start or
                        % tstep(i)*Fsi frames of the end raises an index error.
                        nfr = (pstep + tstep(i)) * Fsi + 1;
                        dtr = zeros(length(iduse), nfr, size(mfst, 1));
                        tempidx = zeros(length(iduse), nfr);
                        for ii = 1: length(iduse)
                            idt = iduse(ii) - pstep * Fsi: iduse(ii) + tstep(i) * Fsi;
                            dtr(ii, :, :) = dataraw(:, idt)';
                            tempidx(ii, :) = idt;
                        end
                        
                        %%% recode trace to get states %%%
                        % average each bin of `ratio` frames and label it with
                        % the nearest cluster center (pdist2, Euclidean default)
                        dtod = zeros(length(iduse), round(size(dtr, 2) / ratio));
                        for ii = 1: length(iduse)
                            for jj = 1: size(dtod, 2)
                                tmp = squeeze(dtr(ii, (jj - 1) * ratio + 1: jj * ratio, :));
                                if iscolumn(tmp)
                                    tmp = tmp';
                                end
                                tmp = mean(tmp, 1);
                                [~, idt] = min(pdist2(tmp, centers'));
                                dtod(ii, jj) = idt;
                            end
                        end
                        
                        %%% get cluster traces %%%
                        % trials x bins x dims path through the cluster centers
                        dtop = permute(reshape(centers(:, dtod), [], size(dtod, 1), size(dtod, 2)), [2, 3, 1]);
                        
                        %%% compute transition matrix %%%
                        nmax = length(pbt);
                        mtx = compute_transition_matrix(dtod, Fsin, nmax);
                        
                        %%% split into the four trial groups (conditions i > 1) %%%
                        if i > 1
                            whwpt = tid.whwp{i - 1, j, k};
                            whnwpt = tid.whnwp{i - 1, j, k};
                            nwhwpt = tid.nwhwp{i - 1, j, k};
                            nwhnwpt = tid.nwhnwp{i - 1, j, k};
                            
                            dtodw = {dtod(whwpt, :), dtod(whnwpt, :), dtod(nwhwpt, :), dtod(nwhnwpt, :)};
                            % keep all three dimensions so the subsets match ctra
                            dtopw = {dtop(whwpt, :, :), dtop(whnwpt, :, :), dtop(nwhwpt, :, :), dtop(nwhnwpt, :, :)};
                            dtrw = {dtr(whwpt, :, :), dtr(whnwpt, :, :), dtr(nwhwpt, :, :), dtr(nwhnwpt, :, :)};
                            tempidxw = {tempidx(whwpt, :), tempidx(whnwpt, :), tempidx(nwhwpt, :), tempidx(nwhnwpt, :)};
                            mtxw = cell(1, 4);
                            for ii = 1: 4
                                mtxw{ii} = compute_transition_matrix(dtodw{ii}, Fsin, nmax);
                            end
                        else
                            dtodw = {};
                            dtopw = {};
                            dtrw = {};
                            tempidxw = {};
                            mtxw = {};
                        end
                        
                        %%% output %%%
                        dto.idxa{i, j, k} = dtod;
                        dto.idxw{i, j, k} = dtodw;
                        dto.ctra{i, j, k} = dtop;
                        dto.ctrw{i, j, k} = dtopw;
                        dto.mtxa{i, j, k} = mtx;
                        dto.mtxw{i, j, k} = mtxw;
                        dto.dta{i, j, k} = dtr;
                        dto.dtw{i, j, k} = dtrw;
                        dto.tempidx{i, j, k} = tempidx;
                        dto.tempidxw{i, j, k} = tempidxw;
                    end
                end
            end
            
            dataout = {dto};
            
        case 'trajectory_analysis_cluster_distr'
            % Windowed cluster-occupancy distributions.
            % Each trials x bins index matrix from pbf.idxa is cut into
            % tstep(cond) + 1 consecutive windows of Fsi bins; each window gives
            % the fraction of entries equal to each index 1..ncmax, where ncmax
            % is the largest index present in that idxa matrix. The same is done
            % for every trial subset in pbf.idxw (normalised by the count total).
            % Output: dataout = {all-trial distributions, per-subset distributions},
            %         each distribution ncmax x (tstep(cond) + 1).
            pbf = datain{1};
            idxAll = pbf.idxa;
            idxSub = pbf.idxw;
            nLead = 1;   % extra leading window
            [nc, nm, nw] = size(idxAll);
            
            distrAll = idxAll;
            distrSub = idxSub;
            for i = 1: nc
                for j = 1: nm
                    for k = 1: nw
                        A = idxAll{i, j, k};
                        S = idxSub{i, j, k};
                        nWin = tstep(i) + nLead;
                        ncmax = max(A(:));
                        
                        %%% all trials %%%
                        h = zeros(ncmax, nWin);
                        for win = 1: nWin
                            seg = A(:, (win - 1) * Fsi + 1: win * Fsi);
                            h(:, win) = histc(seg(:), 1: ncmax) / numel(seg);
                        end
                        distrAll{i, j, k} = h;
                        
                        %%% trial subsets %%%
                        for sub = 1: length(S)
                            h = zeros(ncmax, nWin);
                            for win = 1: nWin
                                seg = S{sub}(:, (win - 1) * Fsi + 1: win * Fsi);
                                cnt = histc(seg(:), 1: ncmax);
                                h(:, win) = cnt / sum(cnt);
                            end
                            distrSub{i, j, k}{sub} = h;
                        end
                    end
                end
            end
            
            dataout = {distrAll, distrSub};
            
        case {'trajectory_analysis_cluster_dist', 'trajectory_analysis_state_dist'}
            % Step-to-step distances along each trial's trajectory.
            %   'trajectory_analysis_cluster_dist' uses pbf.ctra / pbf.ctrw
            %   'trajectory_analysis_state_dist'   uses pbf.dta  / pbf.dtw
            % For every trial, compute_distance is evaluated between consecutive
            % bins whose index in pbf.idxa / pbf.idxw is positive.
            % Output: dataout = {all-trial distances, per-subset distances};
            %         each an ntrials x (nbins - 1) matrix (unused entries 0).
            pbf = datain{1};
            if strcmp(sflag, 'trajectory_analysis_cluster_dist')
                trajAll = pbf.ctra;
                trajSub = pbf.ctrw;
            else
                trajAll = pbf.dta;
                trajSub = pbf.dtw;
            end
            idxa = pbf.idxa;
            idxw = pbf.idxw;
            [nc, nm, nw] = size(trajAll);
            
            dataoa = idxa;
            dataow = idxw;
            for i = 1: nc
                for j = 1: nm
                    for k = 1: nw
                        nSub = length(idxw{i, j, k});
                        % group 0 = all trials; groups 1..nSub = trial subsets
                        for grp = 0: nSub
                            if grp == 0
                                traj = trajAll{i, j, k};
                                ix = idxa{i, j, k};
                            else
                                traj = trajSub{i, j, k}{grp};
                                ix = idxw{i, j, k}{grp};
                            end
                            [nt, nf] = size(ix);
                            d = zeros(nt, nf - 1);
                            for t = 1: nt
                                valid = find(ix(t, :) > 0);
                                for st = 1: length(valid) - 1
                                    d(t, st) = compute_distance(traj, valid(st), valid(st + 1), t);
                                end
                            end
                            if grp == 0
                                dataoa{i, j, k} = d;
                            else
                                dataow{i, j, k}{grp} = d;
                            end
                        end
                    end
                end
            end
            
            dataout = {dataoa, dataow};

        case 'trial_separation'
            % Split the valid trials of every condition i >= 2 into four groups
            % by crossing two binary labels:
            %   wh : fdata.b > 0
            %   wp : at least one value of fdata.wp lies strictly inside
            %        ((s + sl(1,i)*Fsi) * Fsb/Fsi, (s + sl(2,i)*Fsi) * Fsb/Fsi),
            %        where s is the trial's fdata.s value.
            % A trial is valid when both fdata.s and fdata.b are non-NaN.
            % Output: dataout{1} - struct with fields whnwp, nwhnwp, whwp, nwhwp;
            %         each an (nc-1) x nw x nm cell of indices into the valid trials.
            fdata = datain{1};
            fdatas = fdata.s;
            fdatab = fdata.b;
            fdatawp = fdata.wp;
            [nc, nw, nm] = size(fdatas);
            sl = [0, 3, 0; 1, 5, 1];   % window start (row 1) / end (row 2) per condition, in units of Fsi
            
            tido = struct;
            for i = 2: nc
                for j = 1: nw
                    for k = 1: nm
                        tmps = fdatas{i, j, k};
                        tmpb = fdatab{i, j, k};
                        valid = ~isnan(tmps) & ~isnan(tmpb);
                        tmpwh = tmpb(valid) > 0;
                        
                        tmpp = fdatawp{i, j, k};
                        suse = tmps(valid);
                        % Same shape as tmpwh: a row/column mismatch would make
                        % the logical combinations below broadcast to a matrix.
                        tmpwp = false(size(tmpwh));
                        for ii = 1: numel(suse)
                            lo = (suse(ii) + sl(1, i) * Fsi) * Fsb / Fsi;
                            hi = (suse(ii) + sl(2, i) * Fsi) * Fsb / Fsi;
                            hit = tmpp > lo & tmpp < hi;
                            tmpwp(ii) = any(hit(:));
                        end
                        
                        %%% type of behavior variable to use %%%
                        tido.whnwp{i - 1, j, k} = find(tmpwh & ~tmpwp);
                        tido.nwhnwp{i - 1, j, k} = find(~tmpwh & ~tmpwp);
                        tido.whwp{i - 1, j, k} = find(tmpwh & tmpwp);
                        tido.nwhwp{i - 1, j, k} = find(~tmpwh & tmpwp);
                    end
                end
            end
            
            dataout = {tido};
            
        case 'trajectory_analysis_major_state'
            % For each of the four groups {whwp, whnwp, nwhwp, nwhnwp}, compute
            % per row and column how far the group's value exceeds the largest
            % value of the other three groups, then sort each column by that
            % margin (descending). Rows are presumably clusters and columns time
            % windows (e.g. output 2 of 'trajectory_analysis_cluster_distr').
            % Input : datain{1} - cell array; element {i,j,k} (i >= 2) holds four
            %         equally sized matrices in the group order above.
            % Output: dataout{1} - struct with fields <group>dt (sorted margins)
            %         and <group>id (row order), each an (nc-1) x nw x nm cell.
            pbf = datain{1};
            [nc, nw, nm] = size(pbf);
            
            datao = struct;
            for i = 2: nc
                for j = 1: nw
                    for k = 1: nm
                        grp = pbf{i, j, k};
                        g1 = grp{1};
                        g2 = grp{2};
                        g3 = grp{3};
                        g4 = grp{4};
                        if size(g1, 2) == 0
                            continue
                        end
                        P = cat(3, g1, g2, g3, g4);
                        % column-wise stable sort == sorting each column alone
                        [d1, o1] = sort(P(:, :, 1) - max(P(:, :, [2, 3, 4]), [], 3), 1, 'descend');
                        [d2, o2] = sort(P(:, :, 2) - max(P(:, :, [1, 3, 4]), [], 3), 1, 'descend');
                        [d3, o3] = sort(P(:, :, 3) - max(P(:, :, [1, 2, 4]), [], 3), 1, 'descend');
                        [d4, o4] = sort(P(:, :, 4) - max(P(:, :, [1, 2, 3]), [], 3), 1, 'descend');
                        datao.whwpdt{i - 1, j, k} = d1;
                        datao.whwpid{i - 1, j, k} = o1;
                        datao.whnwpdt{i - 1, j, k} = d2;
                        datao.whnwpid{i - 1, j, k} = o2;
                        datao.nwhwpdt{i - 1, j, k} = d3;
                        datao.nwhwpid{i - 1, j, k} = o3;
                        datao.nwhnwpdt{i - 1, j, k} = d4;
                        datao.nwhnwpid{i - 1, j, k} = o4;
                    end
                end
            end
            
            dataout = {datao};
            
        case 'compute_synchrony'
            % Population activity measures per (cond, whisk, mouse) cell:
            %   fp - fraction of rows above threshold at every position along
            %        the remaining dimensions (sum over dim 1, then squeeze);
            %   sn - normalize() of (prod over rows of (1 + Gaussian-smoothed
            %        binary activity) - 1), smoothing along dimension 2.
            % Input : datain{1} - activity cells; when an element is itself a cell
            %                     array its second entry is used.
            %         datain{2} - reference cells; std along dim 2 of each gives
            %                     the per-row binarization threshold.
            % Output: dataout{1} - struct with fields fp and sn (nc x nw x nm cells).
            data = datain{1};
            spk = datain{2};
            kns = 1 * Fsi;   % kernel length in samples
            % Normalized Gaussian kernel as a row vector so convn smooths along
            % dimension 2; loop-invariant, so it is built once here.
            kn = gausswin(kns);
            kn = kn(:)' / sum(kn);
            [nc, nw, nm] = size(data);
            
            datao = struct;
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        if iscell(data{i, j, k})
                            tmp = data{i, j, k}{2};
                        else
                            tmp = data{i, j, k};
                        end
                        rtmp = spk{i, j, k};
                        thres = std(rtmp, [], 2);
                        
                        %%% compute firing population percentage %%%
                        tmpb = tmp > thres;
                        dt1 = squeeze(sum(tmpb, 1)) / size(tmpb, 1);
                        
                        %%% compute synchrony using gaussian kernel %%%
                        tmpbg = 1 + convn(tmpb, kn, 'same');
                        dt2 = squeeze(prod(tmpbg, 1) - 1);
                        
                        datao.fp{i, j, k} = dt1;
                        datao.sn{i, j, k} = normalize(dt2);
                    end
                end
            end
            
            dataout = {datao};
            
        case 'recode_synchrony'
            pbf = datain{1};
            syn = datain{2};
            [nc, nw, nm] = size(pbf);
            
            % In every (i, j, k) cell, replace each cluster's column-index set
            % with the firing-fraction rows stacked over the synchrony rows,
            % restricted to those columns.
            datao = cell(nc, nw, nm);
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        fpmat = syn.fp{i, j, k};
                        snmat = syn.sn{i, j, k};
                        clusters = pbf{i, j, k};
                        recoded = clusters;
                        for c = 1: length(clusters)
                            cols = clusters{c};
                            recoded{c} = [fpmat(:, cols); snmat(:, cols)];
                        end
                        datao{i, j, k} = recoded;
                    end
                end
            end
            
            dataout = {datao};
            
        case 'cluster_major_synchrony'
            pbf = datain{1};
            syn = datain{2};
            [nc, nw, nm] = size(pbf.whwpdt);
            
            % Category stems: pbf.<stem>dt holds the sorted differences and
            % pbf.<stem>id the matching cluster ids, one column per entry ii.
            cats = {'whwp', 'whnwp', 'nwhwp', 'nwhnwp'};
            ncat = numel(cats);
            
            % outp{c} / outn{c}: nc x nw cells; each entry is a 1 x ncol cell of
            % stacked mean-synchrony rows for clusters whose difference in
            % category c is positive / negative, pooled over k.
            outp = repmat({cell(nc, nw)}, 1, ncat);
            outn = repmat({cell(nc, nw)}, 1, ncat);
            for i = 1: nc
                for j = 1: nw
                    ncol = size(pbf.whwpid{i, j, 1}, 2);
                    accp = repmat({cell(1, ncol)}, 1, ncat);
                    accn = repmat({cell(1, ncol)}, 1, ncat);
                    for k = 1: nm
                        ids = cell(1, ncat);
                        dts = cell(1, ncat);
                        for c = 1: ncat
                            ids{c} = pbf.([cats{c}, 'id']){i, j, k};
                            dts{c} = pbf.([cats{c}, 'dt']){i, j, k};
                        end
                        % column-wise mean of every cluster's synchrony matrix
                        ctp = cell2mat(cellfun(@(x) mean(x, 2), syn{i, j, k}, 'uniformoutput', false));
                        
                        % all clusters on each side of zero are pooled (an earlier
                        % variant kept only the first positive / last negative one)
                        for ii = 1: size(dts{1}, 2)
                            for c = 1: ncat
                                col = ids{c}(:, ii);
                                pos = col(dts{c}(:, ii) > 0);
                                neg = col(dts{c}(:, ii) < 0);
                                accp{c}{ii} = [accp{c}{ii}; ctp(:, pos)'];
                                accn{c}{ii} = [accn{c}{ii}; ctp(:, neg)'];
                            end
                        end
                    end
                    
                    for c = 1: ncat
                        outp{c}{i, j} = accp{c};
                        outn{c}{i, j} = accn{c};
                    end
                end
            end
            
            % fields: whwpp, whwpn, whnwpp, whnwpn, nwhwpp, nwhwpn, nwhnwpp, nwhnwpn
            for c = 1: ncat
                datao.([cats{c}, 'p']) = outp{c};
                datao.([cats{c}, 'n']) = outn{c};
            end
            dataout = {datao};
            
        case 'mf_subspace_conditions'
            pbf = datain{1};
            dt = pbf.dtw;
            [nc, nw, nm] = size(dt);
            
            % Results are stored at the same (i, j, k) as the input, starting
            % at i = 2, and 'mf_subspace_dims' reads them back with i = 2: nc.
            % Allocate the full nc rows (row 1 stays empty) so the shape does
            % not depend on which cells happen to be filled.
            eigvecs = cell(nc, nw, nm);
            projs = cell(nc, nw, nm);
            eigvals = cell(nc, nw, nm);
            for i = 2: nc
                for j = 1: nw
                    for k = 1: nm
                        dtmp = dt{i, j, k};
                        for ii = 1: length(dtmp)
                            % bring dim 3 to the front and flatten the rest, so
                            % each column is one sample over dim-3 features
                            tmp = permute(dtmp{ii}, [3, 1, 2]);
                            tmp = reshape(tmp, size(tmp, 1), []);
                            % keep only columns with a positive sum
                            tmp = tmp(:, sum(tmp, 1) > 0);
                            [coeff, score, latent] = pca(tmp');
                            eigvecs{i, j, k}{ii} = coeff;
                            projs{i, j, k}{ii} = score;
                            eigvals{i, j, k}{ii} = latent;
                        end
                    end
                end
            end
            datao.eigvecs = eigvecs;
            datao.projs = projs;
            datao.eigvals = eigvals;
            
            dataout = {datao};
            
        case 'mf_subspace_dims'
            mfsub = datain{1};
            [nc, nw, nm] = size(mfsub.eigvals);
            ndim = length(mfsub.eigvals{2, 1, 1}{1});
            thres = 0.99;   % cumulative explained-variance cut-off (99%)
            
            vals = cell(nc - 1, nw);
            ssps = cell(nc - 1, nw, nm);
            for i = 2: nc
                for j = 1: nw
                    %%% collect eigenvalues: one nm x ndim matrix per category %%%
                    % rows stay NaN where the eigenvalue vector is empty
                    ev = repmat({NaN(nm, ndim)}, 1, 4);
                    for k = 1: nm
                        tmp = mfsub.eigvals{i, j, k};
                        for c = 1: 4
                            if ~isempty(tmp{c})
                                ev{c}(k, :) = tmp{c};
                            end
                        end
                    end
                    
                    %%% cumulative fraction of explained variance, per row %%%
                    cev = cellfun(@(x) cumsum(x ./ sum(x, 2, 'omitnan'), 2), ev, 'uniformoutput', false);
                    vals{i - 1, j} = cat(3, cev{:});
                    
                    %%% analyze dimensionality difference %%%
                    % NOTE(review): '< thres' drops the component at which the
                    % cumulative fraction first exceeds thres; include it if the
                    % intent is "components needed to reach thres".
                    for k = 1: nm
                        dt = mfsub.eigvecs{i, j, k};
                        s = cell(1, 4);
                        for c = 1: 4
                            s{c} = dt{c}(:, cev{c}(k, :) < thres)';
                        end
                        
                        %%% find shared subspace %%%
                        ssps{i - 1, j, k} = shared_subspace(s{:});
                    end
                end
            end
            
            datao.vals = vals;
            datao.ssps = ssps;
            dataout = {datao};
            
        case 'compute_metric'
            spk = datain{1};
            Y = datain{2};
            [nc, nw, nm] = size(spk);
            
            % Outputs reuse spk's cell layout. For each cell: drop columns
            % with no activity, build a kNN graph/Laplacian on the rest
            % (k = 0.25% of the kept columns), and estimate the Riemannian
            % metric of the embedding Y on that Laplacian.
            [Ls, rm, Ws] = deal(spk);
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        raw = spk{i, j, k};
                        active = raw(:, sum(raw, 1) > 0);
                        nnb = round(0.0025 * size(active, 2));
                        [Wk, Lk] = compute_geo_laplacian(active', nnb, 'simple');
                        [~, rmk] = riemann_metric(Y{i, j, k}', Lk);
                        Ls{i, j, k} = Lk;
                        rm{i, j, k} = rmk;
                        Ws{i, j, k} = Wk;
                    end
                end
            end
            dataout = {Ls, rm, Ws};
            
        case 'trajectory_analysis_curve_length_metric'
            pbft = datain{1};
            mfs = datain{2};
            spk = datain{3};
            rm = datain{4};
            Ws = datain{5};
            tempidx = pbft.tempidx;
            tempidxw = pbft.tempidxw;
            [nc, nm, nw] = size(tempidx);
            
            dataoa = cell(nc, nm, nw);
            dataow = cell(nc, nm, nw);
            for i = 1: nc
                for j = 1: nm
                    for k = 1: nw
                        rmetric = rm{i, j, k};
                        W = Ws{i, j, k};
                        spkt = spk{i, j, k};
                        % columns with any activity; position in iduse = graph node
                        iduse = find(sum(spkt, 1) > 0);
                        
                        %%% rebuild the kNN graph if the stored one is disconnected %%%
                        gtest = graph(W);
                        ntest = unique(conncomp(gtest));
                        if length(ntest) > 1
                            knn = 20;
                            X = spkt(:, iduse);
                            wflag = 'simple';
                            W = compute_geo_laplacian(X', knn, wflag);
                        end
                        
                        %%% edge length: mean of the quadratic forms at both ends %%%
                        mft = mfs{i, j, k};
                        [nn, ndim, ~] = size(rmetric);
                        iduset = find(W > 0);
                        [y, x] = ind2sub([nn, nn], iduset);
                        ABdiff = mft(:, y)' - mft(:, x)';
                        Amet = rmetric(y, :, :);
                        Bmet = rmetric(x, :, :);
                        t1 = squeeze(dot(repmat(ABdiff, 1, 1, ndim), Amet, 2));
                        t2 = max(0, squeeze(dot(t1, ABdiff, 2)));
                        elen = 0.5 * sqrt(t2);
                        t1 = squeeze(dot(repmat(ABdiff, 1, 1, ndim), Bmet, 2));
                        t2 = max(0, squeeze(dot(t1, ABdiff, 2)));
                        elen = elen + 0.5 * sqrt(t2);
                        
                        %%% get the graph and update new distance %%%
                        % graph() treats zero entries as "no edge", so a zero length
                        % (e.g. from the max(0, .) clip) would silently delete a kNN
                        % edge and could disconnect the graph; keep it at eps
                        elen(elen == 0) = eps;
                        Wnew = W;
                        Wnew(iduset) = elen;
                        Wnew = max(Wnew, Wnew');
                        g = graph(Wnew);
                        
                        %%% all trial %%%
                        % each row of ids is a sequence of frames; distance between
                        % consecutive active frames is stored at the earlier frame
                        ids = tempidx{i, j, k};
                        [nt, nf, ~] = size(ids);
                        distt = cell(nt, 1);
                        parfor ii = 1: nt
                            tmp = ids(ii, :);
                            % loc(m): graph node of frame tmp(m), 0 if inactive
                            [~, loc] = ismember(tmp, iduse);
                            tmpt = find(loc > 0);
                            dtmp = zeros(1, nf - 1);
                            for jj = 1: length(tmpt) - 1
                                dtmp(tmpt(jj)) = compute_distance_metric(loc(tmpt(jj)), loc(tmpt(jj + 1)), g);
                            end
                            distt{ii} = dtmp;
                        end
                        dataoa{i, j, k} = cell2mat(distt);
                        
                        %%% whnwh trial %%%
                        ids = tempidxw{i, j, k};
                        for kk = 1: length(ids)
                            idsk = ids{kk};
                            [nt, nf] = size(idsk);
                            distt = cell(nt, 1);
                            parfor ii = 1: nt
                                tmp = idsk(ii, :);
                                [~, loc] = ismember(tmp, iduse);
                                tmpt = find(loc > 0);
                                dtmp = zeros(1, nf - 1);
                                for jj = 1: length(tmpt) - 1
                                    dtmp(tmpt(jj)) = compute_distance_metric(loc(tmpt(jj)), loc(tmpt(jj + 1)), g);
                                end
                                distt{ii} = dtmp;
                            end
                            dataow{i, j, k}{kk} = cell2mat(distt);
                        end
                    end
                end
            end
            
            dataout = {dataoa, dataow};
            
            


            
        %%% ----------------------------------------------------------- %%%
        %%% old case functions %%%
        
        case 'extract_clusters_linkage_old'
            pb = datain{1};
            [nc, nw, nm] = size(pb);
            thres = 50;   % a label needs more than this many members to count
            
            pbo = cell(nc, nw, nm);
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        % one column of labels per step ii
                        ptmp = pb{i, j, k};
                        % pbot{t}: track t as rows of [step, label]
                        pbot = {};
                        nn = size(ptmp, 2);
                        flag = false;   % true once the first tracks are seeded
                        
                        for ii = 1: nn
                            ptmpt = ptmp(:, ii);
                            % exact member count of every distinct label (hist() with
                            % a scalar 2nd argument would treat it as a bin count,
                            % and its bin positions are not the label values)
                            [labs, ~, grp] = unique(ptmpt);
                            nh = accumarray(grp, 1);
                            big = labs(nh > thres);
                            
                            if isempty(big)
                                continue
                            end
                            
                            if ~flag
                                % seed one track per large label
                                flag = true;
                                for jj = 1: length(big)
                                    pbot{jj} = [ii, big(jj)];
                                end
                            else
                                for jj = 1: length(big)
                                    lnew = find(ptmpt == big(jj));
                                    flagt = false;
                                    % extend the first track that ended at step ii - 1
                                    % and shares members with this label
                                    for kk = 1: length(pbot)
                                        lold = find(ptmp(:, ii - 1) == pbot{kk}(end, 2));
                                        if ~isempty(intersect(lnew, lold)) && ii - pbot{kk}(end, 1) == 1
                                            pbot{kk} = [pbot{kk}; ii, big(jj)];
                                            flagt = true;
                                            break
                                        end
                                    end
                                    % otherwise start a new track
                                    if ~flagt
                                        pbot{end + 1} = [ii, big(jj)];
                                    end
                                end
                            end
                        end
                        pbo{i, j, k} = pbot;
                        fprintf(['\n done (', num2str(i), ',', num2str(j), ',', num2str(k), ') \n'])
                    end
                end
            end
            
            dataout = {pbo};
            
    end
end

function datao = shared_subspace(s1, s2, s3, s4)
% SHARED_SUBSPACE  Run findIntersect on fixed 2-, 3- and 4-way combinations
% of the four inputs.
%
%   datao{1} : 6 x 4 cell, pairs (1,2) (1,3) (1,4) (2,3) (2,4) (3,4)
%   datao{2} : 3 x 4 cell, triples (1,2,3) (1,2,4) (2,3,4)
%   datao{3} : 1 x 4 cell, (1,2,3,4)
%   Each row holds {names}, {2nd findIntersect output}, {1st output}, and
%   {row counts of the participating inputs}.
%
%   NOTE(review): the triple (s1, s3, s4) is not computed; it is left out
%   here too so existing row indices stay valid.
    s = {s1, s2, s3, s4};
    names = {'s1', 's2', 's3', 's4'};
    combos = {[1, 2; 1, 3; 1, 4; 2, 3; 2, 4; 3, 4], ...
              [1, 2, 3; 1, 2, 4; 2, 3, 4], ...
              [1, 2, 3, 4]};
    
    datao = cell(1, 3);
    for grp = 1: numel(combos)
        cmb = combos{grp};
        for r = 1: size(cmb, 1)
            idx = cmb(r, :);
            % every row gets its own intersection (row (2,3) previously
            % reused the (1,4) result)
            [tmp1, tmp2] = findIntersect(s{idx});
            datao{grp}{r, 1} = names(idx);
            datao{grp}{r, 2} = {tmp2};
            datao{grp}{r, 3} = {tmp1};
            datao{grp}{r, 4} = {cellfun(@(x) size(x, 1), s(idx))};
        end
    end
end

function distmp = compute_distance_metric(id1, id2, g)
% COMPUTE_DISTANCE_METRIC  Length of the shortest weighted path in graph g
% from node id1 to node id2; the node sequence itself is discarded.
    [~, pathlen] = shortestpath(g, id1, id2);
    distmp = pathlen;
end

function distmp = compute_distance(ctrat, idt1, idt2, ii)
% COMPUTE_DISTANCE  Pairwise pdist2 distances between two selections along
% dimension 2 of a 3-D array, taken at (scalar) index ii of dimension 1.
%
%   Dimension 3 is the feature axis. idt1 and idt2 may be scalars or
%   vectors; the result is numel(idt1) x numel(idt2).
    nd = size(ctrat, 3);
    % reshape (not squeeze + transpose) so every selected entry becomes one
    % row; squeeze(...)' put features on the rows for vector selections
    c1 = reshape(ctrat(ii, idt1, :), numel(idt1), nd);
    c2 = reshape(ctrat(ii, idt2, :), numel(idt2), nd);
    distmp = pdist2(c1, c2);
end

function dato = reduce_density(datt, ratio)
% REDUCE_DENSITY  Summarize a point cloud by its k-means centroids.
%
%   datt  - one point per row (the layout kmeans expects)
%   ratio - fraction of the points to keep as centroids (default 0.2)
%   dato  - k x size(datt, 2) centroid matrix, k = max(1, round(np * ratio))
    if nargin < 2
        ratio = 0.2;
    end
    
    np = size(datt, 1);
    % round() gives 0 for small clouds, which kmeans rejects; keep at least one
    k = max(1, round(np * ratio));
    [~, dato] = kmeans(datt, k);
end

function dtc = pbclean(dt)
% PBCLEAN  Parse a two-column text table of persistence pairs into numbers.
%
%   dt  - two-column table of strings (converted with convertStringsToChars).
%         Column 1 holds tokens like "[birth", column 2 tokens like "death]".
%         The dimension-0 block starts at the first row whose column 1 is
%         exactly "[0"; rows whose column 1 contains "persistence" separate
%         the blocks.
%   dtc - struct with
%         d0 : n0 x 2 pairs, from the first "[0" row up to the first separator
%         d1 : n1 x 2 pairs, after the first separator up to the next
%              separator (or the end of the table)
    c1 = dt(:, 1);
    c2 = dt(:, 2);
    id1 = find(strcmp(c1, '[0'), 1);
    id2 = find(contains(c1, 'persistence'));
    
    %%% dimension 0 %%%
    dtc.d0 = pbclean_pairs(c1, c2, id1: id2(1) - 1);
    
    %%% dimension 1 %%%
    % stop before a further separator (e.g. a dimension-2 block) if present
    if numel(id2) > 1
        last1 = id2(2) - 1;
    else
        last1 = size(dt, 1);
    end
    dtc.d1 = pbclean_pairs(c1, c2, id2(1) + 1: last1);
end

function pairs = pbclean_pairs(c1, c2, rows)
% PBCLEAN_PAIRS  Strip the leading "[" / trailing character from the tokens
% in the given rows and return them as an numel(rows) x 2 numeric matrix.
    pairs = zeros(numel(rows), 2);
    for r = 1: numel(rows)
        tok1 = convertStringsToChars(c1(rows(r)));
        tok2 = convertStringsToChars(c2(rows(r)));
        pairs(r, :) = [str2double(tok1(2: end)), str2double(tok2(1: end - 1))];
    end
end

function mtx = compute_transition_matrix(dtod, Fsin, nmax)
% COMPUTE_TRANSITION_MATRIX  Count label-to-label transitions along the
% columns of dtod in consecutive, non-overlapping windows of Fsin columns.
%
%   dtod - label matrix; each row is treated as an independent sequence and
%          transitions are counted along dimension 2 only. Labels are used
%          directly as indices, so they are assumed to be positive integers.
%   Fsin - window length in columns; trailing columns that do not fill a
%          whole window are ignored.
%   nmax - number of labels (first two dimensions of mtx)
%   mtx  - nmax x nmax x nwin; mtx(a, b, w) counts label a immediately
%          followed by label b within window w.
    nn = floor(size(dtod, 2) / Fsin);
    mtx = zeros(nmax, nmax, nn);
    for i = 1: nn
        % one column per sequence, time running down the rows
        tmp = dtod(:, (i - 1) * Fsin + 1: i * Fsin)';
        % linear indices of the last sample of every column: they have no
        % successor inside the same sequence (columns are size(tmp, 1) long)
        lastidx = size(tmp, 1): size(tmp, 1): numel(tmp);
        for j = 1: nmax
            idt = setdiff(find(tmp == j), lastidx);
            tmp1 = tmp(idt + 1);
            if isempty(tmp1)
                continue
            end
            temp = unique(tmp1);
            nt = histc(tmp1, temp);
            mtx(j, temp, i) = nt;
        end
    end
end





